# -*- coding: utf-8 -*-
"""
DPO（专家/公众/融合）训练曲线图：
- 输出 PDF（可编辑文字，而不是 Type 3）
- 使用 TrueType 字体（Arial），可在 AI/Inkscape 调整字体字号
- 统一字体大小可调
"""

import os
import glob
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator

# ------------------------------------------------------------
# Global tunables
# ------------------------------------------------------------
EMA_ALPHA = 0.2     # default weight of the newest sample in the EMA
FIG_SIZE = (9, 5)   # figure size handed to matplotlib when a figure is created
DPI = 300           # dpi handed to savefig
FONT_SIZE = 16      # base font size; title and legend sizes are derived from it

# ------------------------------------------------------------
# Matplotlib configuration, applied in a single update
# ------------------------------------------------------------
matplotlib.rcParams.update({
    # Font type 42 (TrueType) instead of Type 3, so that text in the exported
    # PDF/PS stays editable in Illustrator / Inkscape rather than being
    # converted to outlines.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    # Must be an installed TrueType font (e.g. Arial, Times New Roman,
    # Helvetica). Switch to a CJK font such as 'SimHei' if Chinese labels
    # are needed.
    "font.family": "Arial",
    # All font sizes are driven by FONT_SIZE.
    "font.size": FONT_SIZE,
    "axes.labelsize": FONT_SIZE,
    "axes.titlesize": FONT_SIZE + 2,
    "legend.fontsize": FONT_SIZE - 2,
})

# ========== 工具函数 ==========
def ema_smooth(values, alpha=EMA_ALPHA):
    """Return the exponential moving average of ``values`` as a list.

    The first output equals the first input; every later output is
    ``alpha * v + (1 - alpha) * previous_output``.

    Args:
        values: Iterable of numbers (list, tuple, NumPy array, pandas
            Series, generator, ...). ``None`` is treated as empty.
        alpha: Weight given to the newest sample.

    Returns:
        A list with one smoothed value per input value, or ``[]`` when the
        input is empty or ``None``.
    """
    if values is None:
        return []
    # Materialise first: `not values` raises ValueError on multi-element
    # NumPy arrays / pandas Series, and `values[0]` is label-based on a
    # Series. A plain list has neither problem.
    values = list(values)
    if not values:
        return []
    smoothed = [values[0]]
    for v in values[1:]:
        smoothed.append(alpha * v + (1 - alpha) * smoothed[-1])
    return smoothed


def prepare_series(steps, values):
    """Clean a (steps, values) pair so it can be plotted as one curve.

    Rows where either the step or the value is missing are dropped, values
    that share a step are averaged, and the result is ordered by step.

    Args:
        steps: Sequence of step identifiers.
        values: Sequence of numeric values, same length as ``steps``.

    Returns:
        ``(steps, values)`` as two plain lists of equal length; ``([], [])``
        when either input is empty or nothing survives the NaN filter.
    """
    if not len(steps) or not len(values):
        return [], []

    frame = pd.DataFrame({"step": steps, "val": values}).dropna()
    if frame.empty:
        return [], []

    # Index = step, value = mean of all rows logged at that step.
    per_step = frame.groupby("step")["val"].mean().sort_index()
    return per_step.index.tolist(), per_step.tolist()


def plot_raw_and_ema(steps, values, title, ylabel, save_path, alpha=EMA_ALPHA):
    """Plot a raw curve together with its EMA-smoothed version and save it as PDF.

    The series is first cleaned with ``prepare_series``. If nothing is left
    to draw, a warning is printed and no file is written.

    Args:
        steps: Step identifiers for the x axis.
        values: Metric values for the y axis.
        title: Figure title (also used in the no-data warning).
        ylabel: Label of the y axis.
        save_path: Destination of the PDF. Missing parent directories are
            created; a bare filename is written to the current directory.
        alpha: EMA weight forwarded to ``ema_smooth``.
    """
    steps, values = prepare_series(steps, values)
    if not values:
        print(f"[WARN] {title}: 无数据可画")
        return

    smooth = ema_smooth(values, alpha)

    # os.path.dirname() is "" for a bare filename and os.makedirs("") raises
    # FileNotFoundError, so only create the directory when there is one.
    target_dir = os.path.dirname(save_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)

    fig, ax = plt.subplots(figsize=FIG_SIZE)
    try:
        ax.plot(steps, values, color="gray", linewidth=1.0, alpha=0.6, label="Raw")
        ax.plot(steps, smooth, color="black", linewidth=2.0, label=f"Smoothed (EMA, α={alpha})")

        ax.set_xlabel("Training Step")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        fig.tight_layout()

        # PDF output; text editability depends on the module-level
        # pdf.fonttype setting.
        fig.savefig(save_path, dpi=DPI, format="pdf")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


# ========== CSV 读取 ==========
def find_col(df, prefer_exact, fallback_contains):
    """Locate a column of ``df`` by name, case-insensitively.

    Exact matches win: the names in ``prefer_exact`` are tried in order
    against the lower-cased column names. Only if none matches is the first
    column (in DataFrame order) returned whose lower-cased name contains any
    token from ``fallback_contains``.

    Args:
        df: Object exposing ``columns`` with string column names.
        prefer_exact: Candidate names, highest priority first.
        fallback_contains: Lower-case substrings used for the fallback search.

    Returns:
        The original column name, or ``None`` if nothing matches.
    """
    columns = list(df.columns)
    # When two columns differ only by case, the later one wins the exact lookup.
    by_lower = dict(zip((col.lower() for col in columns), columns))

    for wanted in (key.lower() for key in prefer_exact):
        if wanted in by_lower:
            return by_lower[wanted]

    return next(
        (col for col in columns
         if any(token in col.lower() for token in fallback_contains)),
        None,
    )


def read_csv_series(csv_path):
    """Load the loss and accuracy curves from a training-log CSV.

    The file is read with pandas' default encoding and re-read as GBK if
    that fails with ``UnicodeDecodeError``. Step, loss and accuracy columns
    are located with ``find_col``; when no step column exists, the row index
    is used as the step.

    Each returned step belongs to the row its value came from. Rows whose
    metric is blank or non-numeric are removed from both lists together, so
    a metric that is only logged on some rows keeps its real steps.

    Args:
        csv_path: Path of the CSV file.

    Returns:
        ``((loss_steps, loss_values), (acc_steps, acc_values))``. A pair is
        ``([], [])`` when the corresponding column is not found.
    """
    try:
        df = pd.read_csv(csv_path)
    except UnicodeDecodeError:
        df = pd.read_csv(csv_path, encoding="gbk")

    step_col = find_col(df, ["step", "training_step", "global_step"], ["step"])
    steps = df[step_col].tolist() if step_col else list(range(len(df)))

    loss_col = find_col(df, ["loss", "training_loss"], ["loss"])
    acc_col = find_col(df, ["pairwise_accuracy", "accuracy"], ["acc", "pairwise"])

    def aligned_series(col):
        """Return (steps, values) for ``col``, dropping rows without a numeric value."""
        if not col:
            return [], []
        numeric = pd.to_numeric(df[col], errors="coerce").tolist()
        # Filter steps and values with the same mask. Truncating ``steps`` to
        # the number of surviving values would pair values with the wrong
        # steps whenever a NaN occurs before the last row.
        kept = [(s, v) for s, v in zip(steps, numeric) if not pd.isna(v)]
        return [s for s, _ in kept], [v for _, v in kept]

    return aligned_series(loss_col), aligned_series(acc_col)


# ========== TensorBoard event ==========
def latest_event_file(path):
    """Resolve ``path`` to a single TensorBoard event file.

    Args:
        path: Either a directory containing ``events.out.tfevents*`` files
            or the path of one file.

    Returns:
        For a directory: the matching file whose path sorts last
        lexicographically, or ``None`` if there is no match.
        For an existing file: ``path`` itself. Otherwise ``None``.
    """
    if not os.path.isdir(path):
        return path if os.path.isfile(path) else None

    candidates = glob.glob(os.path.join(path, "events.out.tfevents*"))
    # max() on the path strings picks the same element as sorting and taking the last.
    return max(candidates) if candidates else None


def read_event_series(path):
    """Load the loss and accuracy curves from a TensorBoard event file.

    ``path`` is resolved with ``latest_event_file``. Scalar tags are chosen
    case-insensitively: an exact name match is preferred, otherwise the
    first tag containing one of the fallback substrings is used.

    Args:
        path: Event file, or a directory containing event files.

    Returns:
        ``((loss_steps, loss_values), (acc_steps, acc_values))``. A pair is
        ``([], [])`` when no event file or no matching tag is found.
    """
    event_file = latest_event_file(path)
    if not event_file:
        return ([], []), ([], [])

    # NOTE(review): a size_guidance of 0 for scalars is presumably meant to
    # keep every scalar event instead of a sample - confirm against the
    # TensorBoard documentation.
    accumulator = event_accumulator.EventAccumulator(
        event_file,
        size_guidance={event_accumulator.SCALARS: 0},
    )
    accumulator.Reload()

    scalar_tags = accumulator.Tags().get("scalars", [])

    def pick(exact_list, contains_list):
        """Return the best-matching scalar tag, or None."""
        for wanted in exact_list:
            for tag in scalar_tags:
                if tag.lower() == wanted.lower():
                    return tag
        return next(
            (tag for tag in scalar_tags
             if any(token in tag.lower() for token in contains_list)),
            None,
        )

    def series(tag):
        """Return (steps, values) recorded under ``tag``; empty lists if tag is falsy."""
        if not tag:
            return [], []
        events = accumulator.Scalars(tag)
        return [ev.step for ev in events], [float(ev.value) for ev in events]

    loss_tag = pick(["loss", "train_loss"], ["loss"])
    acc_tag = pick(["pairwise_accuracy", "accuracy"], ["acc", "pairwise"])

    return series(loss_tag), series(acc_tag)


# ========== 主流程 ==========
def process_path(name, path, out_dir, alpha=EMA_ALPHA):
    """Read one training log and write its loss and accuracy curves as PDFs.

    An existing file ending in ``.csv`` is read with ``read_csv_series``;
    anything else is handed to ``read_event_series``. The output files are
    ``<name>_loss_curve.pdf`` and ``<name>_accuracy_curve.pdf`` in ``out_dir``.

    Args:
        name: Label used in figure titles and output file names.
        path: CSV log file, TensorBoard event file, or event directory.
        out_dir: Directory that receives the PDFs.
        alpha: EMA weight forwarded to ``plot_raw_and_ema``.
    """
    is_csv = os.path.isfile(path) and path.lower().endswith(".csv")
    reader = read_csv_series if is_csv else read_event_series
    loss_series, acc_series = reader(path)

    curves = (
        (
            loss_series,
            f"{name} — Training Loss (Raw + EMA)",
            "Loss",
            f"{name}_loss_curve.pdf",
        ),
        (
            acc_series,
            f"{name} — Pairwise Accuracy (Raw + EMA)",
            "Pairwise Accuracy",
            f"{name}_accuracy_curve.pdf",
        ),
    )

    for (steps, values), title, ylabel, filename in curves:
        # Deliberately no emptiness guard here: plot_raw_and_ema prints a
        # [WARN] line and returns without writing a file when there is
        # nothing to draw. Skipping the call would hide a wrong path or a
        # missing metric from the user.
        plot_raw_and_ema(
            steps, values,
            title,
            ylabel,
            os.path.join(out_dir, filename),
            alpha,
        )


# ============================================================
# 执行入口
# ============================================================
if __name__ == "__main__":

    # Destination directory for every generated PDF.
    out_dir = r"D:\chongxinxuexi\pythonProject1\dpo_model\results_pdf"

    # Label -> log location. process_path treats an existing *.csv file as a
    # CSV log and everything else as a TensorBoard event file/directory.
    inputs = {
        "DPO_Expert": r"D:\chongxinxuexi\pythonProject1\dpo_model\expert\日志\dpo_training_log.csv",
        "DPO_Public": r"D:\chongxinxuexi\pythonProject1\dpo_model\public\tb",
        "DPO_Fusion": r"D:\chongxinxuexi\pythonProject1\dpo_model\fusion_from_expert\tb",
    }

    os.makedirs(out_dir, exist_ok=True)

    for name in inputs:
        print(f"==> Processing: {name}")
        process_path(name, inputs[name], out_dir)

    print("\n🎉 PDF 输出完成！所有文字现在都可以在 AI / Inkscape 中随意编辑！")
